"""Contains tests for networking.py and app.py"""
import functools
import os
import tempfile
import time
from contextlib import asynccontextmanager, closing
from typing import Dict
from unittest.mock import patch

import gradio_client as grc
import numpy as np
import pandas as pd
import pytest
import requests
import starlette.routing
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from gradio_client import media_data

import gradio as gr
from gradio import (
    Blocks,
    Button,
    Interface,
    Number,
    Textbox,
    close_all,
    routes,
    wasm_utils,
)
from gradio.route_utils import (
    FnIndexInferError,
    compare_passwords_securely,
    get_root_url,
    starts_with_protocol,
)


@pytest.fixture()
def test_client():
    """Provide a ``TestClient`` for a launched Interface that doubles its text input.

    After the test finishes, the interface is closed and ``close_all`` is called.
    """
    doubler = Interface(lambda x: x + x, "text", "text")
    app, _, _ = doubler.launch(prevent_thread_lock=True)
    yield TestClient(app)
    doubler.close()
    close_all()


class TestRoutes:
    def test_get_main_route(self, test_client):
        """The root page is served with a 200 status."""
        assert test_client.get("/").status_code == 200

    def test_static_files_served_safely(self, test_client):
        """Encoded ``..`` segments under ``/static`` are answered with 403."""
        # Make sure things outside the static folder are not accessible
        escaping_urls = (
            r"/static/..%2findex.html",
            r"/static/..%2f..%2fapi_docs.html",
        )
        for url in escaping_urls:
            assert test_client.get(url).status_code == 403

    def test_get_config_route(self, test_client):
        """The ``/config/`` route is served with a 200 status."""
        assert test_client.get("/config/").status_code == 200

    def test_favicon_route(self, test_client):
        """The ``/favicon.ico`` route is served with a 200 status."""
        assert test_client.get("/favicon.ico").status_code == 200

    def test_upload_path(self, test_client):
        """Uploading a file returns a path whose contents match the uploaded bytes."""
        with open("test/test_files/alphabet.txt", "rb") as source:
            upload = test_client.post("/upload", files={"files": source})
        assert upload.status_code == 200
        saved_path = upload.json()[0]
        assert "alphabet" in saved_path
        assert saved_path.endswith(".txt")
        with open(saved_path, "rb") as saved:
            assert saved.read() == b"abcdefghijklmnopqrstuvwxyz"

    def test_custom_upload_path(self, gradio_temp_dir):
        """Uploaded files are stored under the ``gradio_temp_dir`` fixture's directory.

        The launched interface is wrapped in ``closing`` so its server is shut
        down when the test ends, whether or not an assertion fails.
        """
        io = Interface(lambda x: x + x, "text", "text")
        with closing(io):
            app, _, _ = io.launch(prevent_thread_lock=True)
            test_client = TestClient(app)
            with open("test/test_files/alphabet.txt", "rb") as f:
                response = test_client.post("/upload", files={"files": f})
            assert response.status_code == 200
            file = response.json()[0]
            assert "alphabet" in file
            assert file.startswith(str(gradio_temp_dir))
            assert file.endswith(".txt")
            with open(file, "rb") as saved_file:
                assert saved_file.read() == b"abcdefghijklmnopqrstuvwxyz"

    def test_predict_route(self, test_client):
        """Posting to ``/api/predict/`` with ``fn_index`` 0 returns the doubled text."""
        payload = {"data": ["test"], "fn_index": 0}
        response = test_client.post("/api/predict/", json=payload)
        assert response.status_code == 200
        assert dict(response.json())["data"] == ["testtest"]

    def test_named_predict_route(self):
        """Events with distinct ``api_name`` values are served under ``/api/<name>/``.

        The launched app is wrapped in ``closing`` so its server is shut down
        even if an assertion fails.
        """
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}1", i, o, api_name="p")
            i.change(lambda x: f"{x}2", i, o, api_name="q")

        with closing(demo):
            app, _, _ = demo.launch(prevent_thread_lock=True)
            client = TestClient(app)
            for route, expected in (("/api/p/", "test1"), ("/api/q/", "test2")):
                response = client.post(route, json={"data": ["test"]})
                assert response.status_code == 200
                assert dict(response.json())["data"] == [expected]

    def test_same_named_predict_route(self):
        """A second event reusing ``api_name="p"`` is reachable at ``/api/p_1/``.

        The launched app is wrapped in ``closing`` so its server is shut down
        even if an assertion fails.
        """
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}0", i, o, api_name="p")
            i.change(lambda x: f"{x}1", i, o, api_name="p")

        with closing(demo):
            app, _, _ = demo.launch(prevent_thread_lock=True)
            client = TestClient(app)
            for route, expected in (("/api/p/", "test0"), ("/api/p_1/", "test1")):
                response = client.post(route, json={"data": ["test"]})
                assert response.status_code == 200
                assert dict(response.json())["data"] == [expected]

    def test_multiple_renamed(self):
        """Renaming a duplicate ``api_name`` avoids colliding with an explicit name.

        Two events named ``p`` and one named ``p_1`` are expected at ``/api/p/``,
        ``/api/p_1/`` and ``/api/p_1_1/`` respectively. The launched app is
        wrapped in ``closing`` so its server is shut down even if an assertion
        fails.
        """
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}0", i, o, api_name="p")
            i.change(lambda x: f"{x}1", i, o, api_name="p")
            i.change(lambda x: f"{x}2", i, o, api_name="p_1")

        expectations = (
            ("/api/p/", "test0"),
            ("/api/p_1/", "test1"),
            ("/api/p_1_1/", "test2"),
        )
        with closing(demo):
            app, _, _ = demo.launch(prevent_thread_lock=True)
            client = TestClient(app)
            for route, expected in expectations:
                response = client.post(route, json={"data": ["test"]})
                assert response.status_code == 200
                assert dict(response.json())["data"] == [expected]

    def test_predict_route_without_fn_index(self, test_client):
        """``/api/predict/`` still answers when the payload omits ``fn_index``."""
        payload = {"data": ["test"]}
        response = test_client.post("/api/predict/", json=payload)
        assert response.status_code == 200
        assert dict(response.json())["data"] == ["testtest"]

    def test_predict_route_batching(self):
        """A batched event answers both single and explicitly batched API requests.

        The app is wrapped in ``closing`` so the launched server is shut down
        even when an assertion fails, and each response's status code is checked
        before its payload is inspected.
        """

        def batch_fn(x):
            return ([f"Hello {word}" for word in x],)

        with gr.Blocks() as demo:
            text = gr.Textbox()
            btn = gr.Button()
            btn.click(batch_fn, inputs=text, outputs=text, batch=True, api_name="pred")

        demo.queue(api_open=True)
        with closing(demo):
            app, _, _ = demo.launch(prevent_thread_lock=True)
            client = TestClient(app)
            response = client.post("/api/pred/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["Hello test"]

            app, _, _ = demo.launch(prevent_thread_lock=True)
            client = TestClient(app)
            response = client.post(
                "/api/pred/", json={"data": [["test", "test2"]], "batched": True}
            )
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == [["Hello test", "Hello test2"]]

    def test_state(self):
        """State persists between two requests that share a ``session_hash``.

        The same payload is posted twice; the second response contains the
        accumulated history. The launched app is wrapped in ``closing`` so its
        server is shut down even if an assertion fails.
        """

        def predict(input, history):
            if history is None:
                history = ""
            history += input
            return history, history

        payload = {"data": ["test", None], "fn_index": 0, "session_hash": "_"}
        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        with closing(io):
            app, _, _ = io.launch(prevent_thread_lock=True)
            client = TestClient(app)
            for expected in ("test", "testtest"):
                response = client.post("/api/predict/", json=payload)
                output = dict(response.json())
                # The state output is returned as None in the response data.
                assert output["data"] == [expected, None]

    def test_get_allowed_paths(self):
        """A temp file is served only when ``allowed_paths`` covers it.

        Without ``allowed_paths`` the request is answered with 403; with either
        the file's directory or the file's own absolute path allowed it is
        served in full. The temp file is closed after writing and removed at
        the end, and every launched app is closed even if an assertion fails.
        """
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as allowed_file:
            allowed_file.write(media_data.BASE64_IMAGE)

        def fetch(**launch_kwargs):
            # Launch a fresh app with the given options and request the file.
            io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
            with closing(io):
                app, _, _ = io.launch(prevent_thread_lock=True, **launch_kwargs)
                return TestClient(app).get(f"/file={allowed_file.name}")

        try:
            assert fetch().status_code == 403

            for allowed in (
                os.path.dirname(allowed_file.name),
                os.path.abspath(allowed_file.name),
            ):
                file_response = fetch(allowed_paths=[allowed])
                assert file_response.status_code == 200
                assert len(file_response.text) == len(media_data.BASE64_IMAGE)
        finally:
            os.remove(allowed_file.name)

    def test_allowed_and_blocked_paths(self):
        """A directory listed in both ``allowed_paths`` and ``blocked_paths`` is blocked.

        With only ``allowed_paths`` the temp file is served (200); when the same
        directory is also in ``blocked_paths`` the request is answered with 403.
        The app is closed and the temp file removed even if an assertion fails.
        """

        def status_for(path, **launch_kwargs):
            # Launch a fresh app with the given options and request ``path``.
            io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
            with closing(io):
                app, _, _ = io.launch(prevent_thread_lock=True, **launch_kwargs)
                return TestClient(app).get(f"/file={path}").status_code

        for also_block, expected_status in ((False, 200), (True, 403)):
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
                try:
                    directory = os.path.dirname(tmp_file.name)
                    launch_kwargs = {"allowed_paths": [directory]}
                    if also_block:
                        launch_kwargs["blocked_paths"] = [directory]
                    status = status_for(tmp_file.name, **launch_kwargs)
                finally:
                    # Close before removing so the delete also works on Windows.
                    tmp_file.close()
                    os.remove(tmp_file.name)
            assert status == expected_status

    def test_get_file_created_by_app(self, test_client):
        """A file returned by a prediction can be fetched, including by byte range.

        The upload is sent through the fixture's app; the prediction and the
        file requests go through the app launched here. That app is wrapped in
        ``closing`` so its server is shut down even if an assertion fails.
        """
        source_path = "test/test_files/alphabet.txt"
        io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
        with closing(io):
            app, _, _ = io.launch(prevent_thread_lock=True)
            client = TestClient(app)
            with open(source_path, "rb") as f:
                file_response = test_client.post("/upload", files={"files": f})
            response = client.post(
                "/api/predict/",
                json={
                    "data": [
                        {
                            "path": file_response.json()[0],
                            "size": os.path.getsize(source_path),
                        }
                    ],
                    "fn_index": 0,
                    "session_hash": "_",
                },
            ).json()
            created_file = response["data"][0]["path"]
            file_response = client.get(f"/file={created_file}")
            assert file_response.is_success

            backwards_compatible_file_response = client.get(f"/file/{created_file}")
            assert backwards_compatible_file_response.is_success

            file_response_with_full_range = client.get(
                f"/file={created_file}", headers={"Range": "bytes=0-"}
            )
            assert file_response_with_full_range.is_success
            assert file_response.text == file_response_with_full_range.text

            file_response_with_partial_range = client.get(
                f"/file={created_file}", headers={"Range": "bytes=0-10"}
            )
            assert file_response_with_partial_range.is_success
            # The requested range 0-10 is inclusive, hence 11 characters.
            assert len(file_response_with_partial_range.text) == 11

    def test_mount_gradio_app(self):
        """Two queued Interfaces mounted at different paths are both reachable."""
        host = FastAPI()

        mounted = {
            "/ps": gr.Interface(
                lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
            ).queue(),
            "/py": gr.Interface(
                lambda s: f"Hello from py, {s}!", "textbox", "textbox"
            ).queue(),
        }
        for path, demo in mounted.items():
            host = gr.mount_gradio_app(host, demo, path=path)

        # Use context manager to trigger start up events
        with TestClient(host) as client:
            for path in mounted:
                assert client.get(path).is_success

    def test_mount_gradio_app_with_app_kwargs(self):
        """``app_kwargs`` passed when mounting change the mounted app's docs URL."""
        echo = gr.Interface(lambda s: f"You said {s}!", "textbox", "textbox").queue()
        host = gr.mount_gradio_app(
            FastAPI(),
            echo,
            path="/echo",
            app_kwargs={"docs_url": "/docs-custom"},
        )

        # Use context manager to trigger start up events
        with TestClient(host) as client:
            docs_response = client.get("/echo/docs-custom")
            assert docs_response.is_success

    def test_mount_gradio_app_with_lifespan(self):
        """Mounting works on a FastAPI app that was created with a custom lifespan."""

        @asynccontextmanager
        async def empty_lifespan(app: FastAPI):
            yield

        host = FastAPI(lifespan=empty_lifespan)

        mounted = {
            "/ps": gr.Interface(
                lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
            ).queue(),
            "/py": gr.Interface(
                lambda s: f"Hello from py, {s}!", "textbox", "textbox"
            ).queue(),
        }
        for path, demo in mounted.items():
            host = gr.mount_gradio_app(host, demo, path=path)

        # Use context manager to trigger start up events
        with TestClient(host) as client:
            for path in mounted:
                assert client.get(path).is_success

    def test_mount_gradio_app_with_startup(self):
        """Mounting works on a FastAPI app that registers its own startup handler."""
        host = FastAPI()

        @host.on_event("startup")
        async def empty_startup():
            return

        mounted = {
            "/ps": gr.Interface(
                lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
            ).queue(),
            "/py": gr.Interface(
                lambda s: f"Hello from py, {s}!", "textbox", "textbox"
            ).queue(),
        }
        for path, demo in mounted.items():
            host = gr.mount_gradio_app(host, demo, path=path)

        # Use context manager to trigger start up events
        with TestClient(host) as client:
            for path in mounted:
                assert client.get(path).is_success

    def test_mount_gradio_app_with_auth_dependency(self):
        """A mounted app with ``auth_dependency`` rejects requests lacking a user header."""

        def get_user(request: Request):
            return request.headers.get("user")

        greeter = gr.Interface(lambda s: f"Hello from ps, {s}!", "textbox", "textbox")
        host = gr.mount_gradio_app(
            FastAPI(), greeter, path="/demo", auth_dependency=get_user
        )

        with TestClient(host) as client:
            with_user = client.get("/demo", headers={"user": "abubakar"})
            assert with_user.is_success
            without_user = client.get("/demo")
            assert not without_user.is_success

    def test_static_file_missing(self, test_client):
        """A nonexistent file under ``/static`` is answered with 404."""
        assert test_client.get(r"/static/not-here.js").status_code == 404

    def test_asset_file_missing(self, test_client):
        """A nonexistent file under ``/assets`` is answered with 404."""
        assert test_client.get(r"/assets/not-here.js").status_code == 404

    def test_cannot_access_files_in_working_directory(self, test_client):
        """Relative ``/file=`` paths are answered with 403."""
        for url in (r"/file=not-here.js", r"/file=subdir/.env"):
            assert test_client.get(url).status_code == 403

    def test_cannot_access_directories_in_working_directory(self, test_client):
        """Requesting ``/file=gradio`` is answered with 403."""
        assert test_client.get(r"/file=gradio").status_code == 403

    def test_block_protocols_that_expose_windows_credentials(self, test_client):
        """A ``/file=`` path starting with ``//host/share`` is answered with 403."""
        assert test_client.get(r"/file=//11.0.225.200/share").status_code == 403

    def test_do_not_expose_existence_of_files_outside_working_directory(
        self, test_client
    ):
        """A missing file reached via ``..`` yields 403 rather than 404."""
        status = test_client.get(r"/file=../fake-file-that-does-not-exist.js").status_code
        assert status == 403  # not a 404

    def test_proxy_route_is_restricted_to_load_urls(self):
        """``build_proxy_request`` raises ``PermissionError`` for URLs not in ``proxy_urls``.

        ``Context.hf_token`` is set through ``patch.object`` so the previous
        value is restored afterwards and the token does not leak into other
        tests.
        """
        private_space = "https://gradio-tests-test-loading-examples-private.hf.space"
        private_file = f"{private_space}/file=Bunny.obj"

        with patch.object(gr.context.Context, "hf_token", "abcdef"):
            app = routes.App()
            interface = gr.Interface(lambda x: x, "text", "text")
            app.configure_app(interface)
            with pytest.raises(PermissionError):
                app.build_proxy_request(private_file)
            with pytest.raises(PermissionError):
                app.build_proxy_request("https://google.com")
            interface.proxy_urls = {private_space}
            # Must not raise once the space is registered as a proxy URL.
            app.build_proxy_request(private_file)

    def test_proxy_does_not_leak_hf_token_externally(self):
        """Proxy requests carry an authorization header only for the ``hf.space`` URL.

        ``Context.hf_token`` is set through ``patch.object`` so the previous
        value is restored afterwards and the token does not leak into other
        tests.
        """
        private_space = "https://gradio-tests-test-loading-examples-private.hf.space"

        with patch.object(gr.context.Context, "hf_token", "abcdef"):
            app = routes.App()
            interface = gr.Interface(lambda x: x, "text", "text")
            interface.proxy_urls = {
                private_space,
                "https://google.com",
            }
            app.configure_app(interface)
            r = app.build_proxy_request(f"{private_space}/file=Bunny.obj")
            assert "authorization" in dict(r.headers)
            r = app.build_proxy_request("https://google.com")
            assert "authorization" not in dict(r.headers)

    def test_can_get_config_that_includes_non_pickle_able_objects(self):
        """The page and config are served when a component holds a ``dict_keys`` value.

        The launched app is wrapped in ``closing`` so its server is shut down
        even if an assertion fails.
        """
        my_dict = {"a": 1, "b": 2, "c": 3}
        with Blocks() as demo:
            gr.JSON(my_dict.keys())

        with closing(demo):
            app, _, _ = demo.launch(prevent_thread_lock=True)
            client = TestClient(app)
            for route in ("/", "/config/"):
                assert client.get(route).is_success

    def test_cors_restrictions(self):
        """The CORS allow-origin header is sent for a local origin but not an external one.

        The launched app is wrapped in ``closing`` so its server is shut down
        even if an assertion fails (previously ``io.close()`` was skipped then).
        """
        io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
        with closing(io):
            app, _, _ = io.launch(prevent_thread_lock=True)
            client = TestClient(app)
            custom_headers = {
                "host": "localhost:7860",
                "origin": "https://example.com",
            }
            file_response = client.get("/config", headers=custom_headers)
            assert "access-control-allow-origin" not in file_response.headers
            custom_headers = {
                "host": "localhost:7860",
                "origin": "127.0.0.1",
            }
            file_response = client.get("/config", headers=custom_headers)
            assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"

    def test_delete_cache(self, connect, gradio_temp_dir, capsys):
        """Cached files are gone after the client disconnects only when ``delete_cache`` is set.

        Also checks that a custom ``lifespan`` passed through ``app_kwargs``
        still runs (its prints are captured) alongside the cache deletion.
        """

        def check_num_files_exist(blocks: Blocks):
            # Count the paths recorded in the app's temp file sets that still exist.
            return sum(
                os.path.exists(temp_file)
                for temp_file_set in blocks.temp_file_sets
                for temp_file in temp_file_set
            )

        keep_demo = gr.Interface(
            lambda s: s, gr.Textbox(), gr.File(), delete_cache=None
        )
        with connect(keep_demo) as client:
            client.predict("test/test_files/cheetah1.jpg")
        assert check_num_files_exist(keep_demo) == 1

        purge_demo = gr.Interface(
            lambda s: s, gr.Textbox(), gr.File(), delete_cache=(60, 30)
        )
        with connect(purge_demo) as client:
            for sample in ("test/test_files/alphabet.txt", "test/test_files/bus.png"):
                client.predict(sample)
            assert check_num_files_exist(purge_demo) == 2
        assert check_num_files_exist(purge_demo) == 0
        assert check_num_files_exist(keep_demo) == 1

        @asynccontextmanager
        async def mylifespan(app: FastAPI):
            print("IN CUSTOM LIFESPAN")
            yield
            print("AFTER CUSTOM LIFESPAN")

        lifespan_demo = gr.Interface(
            lambda s: s, gr.Textbox(), gr.File(), delete_cache=(5, 1)
        )
        with connect(lifespan_demo, app_kwargs={"lifespan": mylifespan}) as client:
            client.predict("test/test_files/alphabet.txt")
        assert check_num_files_exist(lifespan_demo) == 0

        printed = capsys.readouterr().out
        assert "IN CUSTOM LIFESPAN" in printed
        assert "AFTER CUSTOM LIFESPAN" in printed


class TestApp:
    """Tests for ``routes.App`` construction."""

    def test_create_app(self):
        """``App.create_app`` returns a ``FastAPI`` instance for an Interface."""
        interface = Interface(lambda x: x, "text", "text")
        created = routes.App.create_app(interface)
        assert isinstance(created, FastAPI)


class TestAuthenticatedRoutes:
    """Tests for the login/logout routes of an app launched with ``auth``.

    Each launched app is wrapped in ``closing`` so its server is shut down even
    if an assertion fails.
    """

    def test_post_login(self):
        """``/login`` accepts the right password, rejects a wrong one, and
        accepts a username padded with spaces."""

        def login(client, username, password):
            return client.post(
                "/login",
                data={"username": username, "password": password},
            )

        io = Interface(lambda x: x, "text", "text")
        with closing(io):
            app, _, _ = io.launch(
                auth=("test", "correct_password"),
                prevent_thread_lock=True,
            )
            client = TestClient(app)

            assert login(client, "test", "correct_password").status_code == 200
            assert login(client, "test", "incorrect_password").status_code == 400

            # Log in successfully again before trying the padded username.
            login(client, "test", "correct_password")
            assert login(client, " test ", "correct_password").status_code == 200

    def test_logout(self):
        """After ``/logout`` a previously authorised prediction request gets 401."""
        io = Interface(lambda x: x, "text", "text")
        with closing(io):
            app, _, _ = io.launch(
                auth=("test", "correct_password"),
                prevent_thread_lock=True,
            )
            client = TestClient(app)

            client.post(
                "/login",
                data={"username": "test", "password": "correct_password"},
            )

            response = client.post(
                "/run/predict",
                json={"data": ["test"]},
            )
            assert response.status_code == 200

            # The logout response itself is not inspected.
            client.get("/logout")

            response = client.post(
                "/run/predict",
                json={"data": ["test"]},
            )
            assert response.status_code == 401


class TestQueueRoutes:
    """Tests for the queue's link to the server app."""

    @pytest.mark.asyncio
    async def test_queue_join_routes_sets_app_if_none_set(self):
        """After a prediction, the queue's ``server_app`` equals the Blocks' ``server_app``.

        The launched app is wrapped in ``closing`` so its server is shut down
        even if the prediction or the assertion fails.
        """
        io = Interface(lambda x: x, "text", "text").queue()
        with closing(io):
            io.launch(prevent_thread_lock=True)
            # NOTE(review): this clears ``server_path`` although the assertion
            # below is about ``server_app`` -- confirm which attribute was meant.
            io._queue.server_path = None

            client = grc.Client(io.local_url)
            client.predict("test")

            assert io._queue.server_app == io.server_app


class TestDevMode:
    """Tests for ``dev_mode`` on mounted apps."""

    def test_mount_gradio_app_set_dev_mode_false(self):
        """Blocks mounted via ``routes.mount_gradio_app`` do not have ``dev_mode`` set."""
        host = FastAPI()

        @host.get("/")
        def read_main():
            return {"message": "Hello!"}

        with gr.Blocks() as blocks:
            gr.Textbox("Hello from gradio!")

        host = routes.mount_gradio_app(host, blocks, path="/gradio")
        # The first Mount route on the host is taken to be the gradio sub-app.
        mounts = (
            route for route in host.routes if isinstance(route, starlette.routing.Mount)
        )
        gradio_mount = next(mounts)
        assert not gradio_mount.app.blocks.dev_mode


class TestPassingRequest:
    """Tests that a ``gr.Request`` argument is supplied to event functions.

    The assertions about the request live inside each ``identity`` function; a
    200 response with the expected data shows they passed. Each launched app is
    wrapped in ``closing`` so its server is shut down even if an assertion
    fails.
    """

    def test_request_included_with_interface(self):
        """An Interface function receives a request whose client host is a string."""

        def identity(name, request: gr.Request):
            assert isinstance(request.client.host, str)
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        with closing(io):
            app, _, _ = io.launch(prevent_thread_lock=True)
            client = TestClient(app)

            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]

    def test_request_included_with_chat_interface(self):
        """A ChatInterface function receives a request via ``/api/chat/``."""

        def identity(x, y, request: gr.Request):
            assert isinstance(request.client.host, str)
            return x

        chat = gr.ChatInterface(identity)
        with closing(chat):
            app, _, _ = chat.launch(prevent_thread_lock=True)
            client = TestClient(app)

            response = client.post("/api/chat/", json={"data": ["test", None]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test", None]

    def test_request_included_with_chat_interface_when_streaming(self):
        """A streaming ChatInterface function receives a request; the API
        response carries the first yielded chunk."""

        def identity(x, y, request: gr.Request):
            assert isinstance(request.client.host, str)
            # Yield progressively longer prefixes of the message.
            for end in range(1, len(x) + 1):
                yield x[:end]

        chat = gr.ChatInterface(identity).queue(api_open=True)
        with closing(chat):
            app, _, _ = chat.launch(prevent_thread_lock=True)
            client = TestClient(app)

            response = client.post("/api/chat/", json={"data": ["test", None]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["t", None]

    def test_request_get_headers(self):
        """Request headers support item access and list-returning dict-style views."""

        def identity(name, request: gr.Request):
            assert isinstance(request.headers["user-agent"], str)
            assert isinstance(request.headers.items(), list)
            assert isinstance(request.headers.keys(), list)
            assert isinstance(request.headers.values(), list)
            assert isinstance(dict(request.headers), dict)
            user_agent = request.headers["user-agent"]
            assert "testclient" in user_agent
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        with closing(io):
            app, _, _ = io.launch(prevent_thread_lock=True)
            client = TestClient(app)

            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]

    def test_request_includes_username_as_none_if_no_auth(self):
        """``request.username`` is ``None`` when the app is launched without auth."""

        def identity(name, request: gr.Request):
            assert request.username is None
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        with closing(io):
            app, _, _ = io.launch(prevent_thread_lock=True)
            client = TestClient(app)

            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]

    def test_request_includes_username_with_auth(self):
        """``request.username`` is the logged-in user's name when auth is enabled."""

        def identity(name, request: gr.Request):
            assert request.username == "admin"
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        with closing(io):
            app, _, _ = io.launch(
                prevent_thread_lock=True, auth=("admin", "password")
            )
            client = TestClient(app)

            client.post(
                "/login",
                data={"username": "admin", "password": "password"},
            )
            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]


def test_predict_route_is_blocked_if_api_open_false():
    """With ``queue(api_open=False)`` a direct ``/api/predict`` post gets 404.

    The launched app is wrapped in ``closing`` so its server is shut down even
    if an assertion fails.
    """
    io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(
        api_open=False
    )
    with closing(io):
        app, _, _ = io.launch(prevent_thread_lock=True)
        assert io.show_api
        client = TestClient(app)
        result = client.post(
            "/api/predict", json={"fn_index": 0, "data": [5], "session_hash": "foo"}
        )
        assert result.status_code == 404


def test_predict_route_not_blocked_if_queue_disabled():
    """With ``api_open=False`` only the event registered with ``queue=False`` is callable.

    The queued event's route gets 404 while the unqueued one answers normally.
    The launched app is wrapped in ``closing`` so its server is shut down even
    if an assertion fails; locals no longer shadow the ``input`` builtin.
    """
    with Blocks() as demo:
        name_box = Textbox()
        greeting_box = Textbox()
        number = Number()
        button = Button()
        button.click(
            lambda x: f"Hello, {x}!",
            name_box,
            greeting_box,
            queue=False,
            api_name="not_blocked",
        )
        button.click(lambda: 42, None, number, queue=True, api_name="blocked")

    with closing(demo):
        app, _, _ = demo.queue(api_open=False).launch(
            prevent_thread_lock=True, show_api=True
        )
        assert demo.show_api
        client = TestClient(app)

        result = client.post("/api/blocked", json={"data": [], "session_hash": "foo"})
        assert result.status_code == 404
        result = client.post(
            "/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
        )
        assert result.status_code == 200
        assert result.json()["data"] == ["Hello, freddy!"]


def test_predict_route_not_blocked_if_routes_open():
    """With ``api_open=True`` a queued event is callable even when ``show_api=False``.

    The app is then closed and relaunched with ``api_open=False`` to check that
    ``show_api`` stays false. The whole test runs inside ``closing`` so the
    relaunched server is shut down as well, even if an assertion fails; locals
    no longer shadow the ``input`` builtin.
    """
    with Blocks() as demo:
        name_box = Textbox()
        greeting_box = Textbox()
        button = Button()
        button.click(
            lambda x: f"Hello, {x}!",
            name_box,
            greeting_box,
            queue=True,
            api_name="not_blocked",
        )

    with closing(demo):
        app, _, _ = demo.queue(api_open=True).launch(
            prevent_thread_lock=True, show_api=False
        )
        assert not demo.show_api
        client = TestClient(app)

        result = client.post(
            "/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
        )
        assert result.status_code == 200
        assert result.json()["data"] == ["Hello, freddy!"]

        demo.close()
        demo.queue(api_open=False).launch(prevent_thread_lock=True, show_api=False)
        assert not demo.show_api


def test_show_api_queue_not_enabled():
    """``show_api`` is true by default and false after relaunching with ``show_api=False``.

    The test runs inside ``closing`` so the relaunched server is shut down too,
    even if an assertion fails.
    """
    io = Interface(lambda x: x, "text", "text", examples=[["freddy"]])
    with closing(io):
        io.launch(prevent_thread_lock=True)
        assert io.show_api
        io.close()
        io.launch(prevent_thread_lock=True, show_api=False)
        assert not io.show_api


def test_orjson_serialization():
    """The main page is served for a DataFrame mixing dates, numbers, bools and text.

    The launched app is wrapped in ``closing`` so its server is shut down even
    if the assertion fails (previously ``demo.close()`` was skipped then).
    """
    df = pd.DataFrame(
        {
            "date_1": pd.date_range("2021-01-01", periods=2),
            "date_2": pd.date_range("2022-02-15", periods=2).strftime("%B %d, %Y, %r"),
            "number": np.array([0.2233, 0.57281]),
            "number_2": np.array([84, 23]).astype(np.int64),
            "bool": [True, False],
            "markdown": ["# Hello", "# Goodbye"],
        }
    )

    with gr.Blocks() as demo:
        gr.DataFrame(df)

    with closing(demo):
        app, _, _ = demo.launch(prevent_thread_lock=True)
        test_client = TestClient(app)
        response = test_client.get("/")
        assert response.status_code == 200


def test_api_name_set_for_all_events(connect):
    """Events registered without ``api_name`` get a route derived from their function.

    Covers plain functions, functions with unusual ``__name__`` values, a
    lambda, a callable instance and a ``functools.partial``; the expected routes
    are listed in ``expected_by_route`` below.
    """
    with gr.Blocks() as demo:
        i = Textbox()
        o = Textbox()
        buttons = [Button() for _ in range(9)]

        def greet(i):
            return "Hello " + i

        def goodbye(i):
            return "Goodbye " + i

        def greet_me(i):
            return "Hello"

        def say_goodbye(i):
            return "Goodbye"

        say_goodbye.__name__ = "Say_$$_goodbye"

        # Otherwise changed by ruff
        foo = lambda s: s  # noqa

        def foo2(s):
            return s + " foo"

        foo2.__name__ = "foo-2"

        class Callable:
            def __call__(self, a) -> str:
                return "From __call__"

        def from_partial(a, b):
            return b + a

        part = functools.partial(from_partial, b="From partial: ")

        # One handler per button, registered in this order.
        handlers = (
            greet,
            goodbye,
            greet_me,
            say_goodbye,
            None,
            foo,
            foo2,
            Callable(),
            part,
        )
        for button, handler in zip(buttons, handlers):
            button.click(handler, i, o)

    expected_by_route = (
        ("/api/greet", ["Hello freddy"]),
        ("/api/goodbye", ["Goodbye freddy"]),
        ("/api/greet_me", ["Hello"]),
        ("/api/Say__goodbye", ["Goodbye"]),
        ("/api/lambda", ["freddy"]),
        ("/api/foo-2", ["freddy foo"]),
        ("/api/Callable", ["From __call__"]),
        ("/api/partial", ["From partial: freddy"]),
    )

    with closing(demo) as io:
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)

        def post(route):
            return client.post(
                route, json={"data": ["freddy"], "session_hash": "foo"}
            )

        for route, expected in expected_by_route:
            assert post(route).json()["data"] == expected
        with pytest.raises(FnIndexInferError):
            post("/api/Say_goodbye")

    with connect(demo) as client:
        assert client.predict("freddy", api_name="/greet") == "Hello freddy"
        assert client.predict("freddy", api_name="/goodbye") == "Goodbye freddy"
        assert client.predict("freddy", api_name="/greet_me") == "Hello"
        assert client.predict("freddy", api_name="/Say__goodbye") == "Goodbye"


class TestShowAPI:
    """Checks ``Interface.show_api`` against the patched ``wasm_utils.IS_WASM`` flag."""

    @patch.object(wasm_utils, "IS_WASM", True)
    def test_show_api_false_when_is_wasm_true(self):
        """With ``IS_WASM`` patched to True, a new Interface reports ``show_api`` False."""
        demo = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
        show_api = demo.show_api
        assert show_api is False, "show_api should be False when IS_WASM is True"

    @patch.object(wasm_utils, "IS_WASM", False)
    def test_show_api_true_when_is_wasm_false(self):
        """With ``IS_WASM`` patched to False, a new Interface reports ``show_api`` True."""
        demo = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
        show_api = demo.show_api
        assert show_api is True, "show_api should be True when IS_WASM is False"


def test_component_server_endpoints(connect):
    """Exercise the ``/component_server/`` route against a FileExplorer component.

    Posting ``fn_name="ls"`` must return 200 with a non-empty JSON body, while
    posting ``fn_name="preprocess"`` must return 404.
    """
    test_dir = os.path.dirname(os.path.abspath(__file__))
    with gr.Blocks() as demo:
        file_explorer = gr.FileExplorer(root=test_dir)

    def call_component_fn(http_client, fn_name):
        # Both requests share everything except the function name.
        payload = {
            "session_hash": "123",
            "component_id": file_explorer._id,
            "fn_name": fn_name,
            "data": None,
        }
        return http_client.post("/component_server/", json=payload)

    with closing(demo) as io:
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)

        listing = call_component_fn(client, "ls")
        assert listing.status_code == 200
        assert len(listing.json()) > 0

        rejected = call_component_fn(client, "preprocess")
        assert rejected.status_code == 404


@pytest.mark.parametrize(
    "request_url, route_path, root_path, expected_root_url",
    [
        ("http://localhost:7860/", "/", None, "http://localhost:7860"),
        (
            "http://localhost:7860/demo/test",
            "/demo/test",
            None,
            "http://localhost:7860",
        ),
        (
            "http://localhost:7860/demo/test/",
            "/demo/test",
            None,
            "http://localhost:7860",
        ),
        (
            "http://localhost:7860/demo/test?query=1",
            "/demo/test",
            None,
            "http://localhost:7860",
        ),
        (
            "http://localhost:7860/demo/test?query=1",
            "/demo/test/",
            "/gradio/",
            "http://localhost:7860/gradio",
        ),
        (
            "http://localhost:7860/demo/test?query=1",
            "/demo/test",
            "/gradio/",
            "http://localhost:7860/gradio",
        ),
        (
            "https://localhost:7860/demo/test?query=1",
            "/demo/test",
            "/gradio/",
            "https://localhost:7860/gradio",
        ),
        (
            "https://www.gradio.app/playground/",
            "/",
            "/playground",
            "https://www.gradio.app/playground",
        ),
        (
            "https://www.gradio.app/playground/",
            "/",
            "",
            "https://www.gradio.app/playground",
        ),
        (
            "https://www.gradio.app/playground/",
            "/",
            "http://www.gradio.app/",
            "http://www.gradio.app",
        ),
    ],
)
def test_get_root_url(
    request_url: str, route_path: str, root_path: str, expected_root_url: str
):
    """Check ``get_root_url`` for each request URL / route path / root path case.

    ``root_path`` is ``None`` in the cases that configure no root path. The
    full request URL is placed in the ASGI scope's ``"path"`` entry
    (presumably so the Request reports it as its URL -- confirm against
    starlette's URL handling).
    """
    scope = {
        "type": "http",
        "headers": [],
        "path": request_url,
    }
    request = Request(scope)
    assert get_root_url(request, route_path, root_path) == expected_root_url


@pytest.mark.parametrize(
    "headers, root_path, expected_root_url",
    [
        ({}, "/gradio/", "http://gradio.app/gradio"),
        ({"x-forwarded-proto": "http"}, "/gradio/", "http://gradio.app/gradio"),
        ({"x-forwarded-proto": "https"}, "/gradio/", "https://gradio.app/gradio"),
        ({"x-forwarded-host": "gradio.dev"}, "/gradio/", "http://gradio.dev/gradio"),
        (
            {"x-forwarded-host": "gradio.dev", "x-forwarded-proto": "https"},
            "/",
            "https://gradio.dev",
        ),
        (
            {"x-forwarded-host": "gradio.dev", "x-forwarded-proto": "https"},
            "http://google.com",
            "http://google.com",
        ),
    ],
)
def test_get_root_url_headers(
    headers: Dict[str, str], root_path: str, expected_root_url: str
):
    """Check ``get_root_url`` when ``x-forwarded-*`` headers are on the request.

    Every case uses the request URL ``http://gradio.app`` and route path ``/``;
    only the headers and ``root_path`` vary.
    """
    # ASGI scopes carry headers as (name, value) pairs of bytes.
    encoded_headers = [
        (name.encode(), value.encode()) for name, value in headers.items()
    ]
    request = Request(
        {
            "type": "http",
            "headers": encoded_headers,
            "path": "http://gradio.app",
        }
    )
    root_url = get_root_url(request, "/", root_path)
    assert root_url == expected_root_url


class TestSimpleAPIRoutes:
    """Tests for the ``call/{api_name}`` HTTP routes of a launched Blocks app."""

    def get_demo(self):
        """Build a Blocks app exposing the ``fn1``, ``fn2`` and ``fn3`` endpoints.

        Returns:
            The (not yet launched) Blocks instance.
        """
        with Blocks() as demo:
            input_box = Textbox()
            output = Textbox()
            output2 = Textbox()

            def fn_1(x):
                return f"Hello, {x}!"

            def fn_2(x):
                # Yield a greeting for each successively longer prefix of x,
                # then raise for inputs shorter than 3 characters so the
                # error event can be observed after some output was streamed.
                for end in range(1, len(x) + 1):
                    time.sleep(0.5)
                    yield f"Hello, {x[:end]}!"
                if len(x) < 3:
                    raise ValueError("Small input")

            def fn_3():
                return "a", "b"

            btn1, btn2, btn3 = Button(), Button(), Button()
            btn1.click(fn_1, input_box, output, api_name="fn1")
            btn2.click(fn_2, input_box, output2, api_name="fn2")
            btn3.click(fn_3, None, [output, output2], api_name="fn3")
        return demo

    @staticmethod
    def _call_and_collect(demo, api_name, data):
        """Submit ``data`` to ``api_name`` and return the streamed event lines.

        Args:
            demo: A launched Blocks app; its ``local_url`` is the base URL.
            api_name: Endpoint name used in the ``call/{api_name}`` routes.
            data: List sent as the ``"data"`` field of the JSON body.

        Returns:
            The non-empty lines of the event stream, decoded as UTF-8.
        """
        submit = requests.post(
            f"{demo.local_url}call/{api_name}", json={"data": data}
        )
        assert submit.status_code == 200, f"Failed to call {api_name}"
        event_id = submit.json()["event_id"]

        # The context manager releases the streamed connection even if
        # iteration stops early.
        with requests.get(
            f"{demo.local_url}call/{api_name}/{event_id}", stream=True
        ) as stream:
            return [line.decode("utf-8") for line in stream.iter_lines() if line]

    def test_successful_simple_route(self):
        """Non-generator endpoints emit a single ``complete`` event with the output."""
        # closing() shuts the server down even when an assertion fails.
        with closing(self.get_demo()) as demo:
            demo.launch(prevent_thread_lock=True)

            assert self._call_and_collect(demo, "fn1", ["world"]) == [
                "event: complete",
                'data: ["Hello, world!"]',
            ]
            assert self._call_and_collect(demo, "fn3", []) == [
                "event: complete",
                'data: ["a", "b"]',
            ]

    def test_generative_simple_route(self):
        """Generator endpoints emit ``generating`` events, then ``complete`` or ``error``."""
        # closing() shuts the server down even when an assertion fails.
        with closing(self.get_demo()) as demo:
            demo.launch(prevent_thread_lock=True)

            assert self._call_and_collect(demo, "fn2", ["world"]) == [
                "event: generating",
                'data: ["Hello, w!"]',
                "event: generating",
                'data: ["Hello, wo!"]',
                "event: generating",
                'data: ["Hello, wor!"]',
                "event: generating",
                'data: ["Hello, worl!"]',
                "event: generating",
                'data: ["Hello, world!"]',
                "event: complete",
                'data: ["Hello, world!"]',
            ]

            # A one-character input streams one value and then raises in fn_2.
            assert self._call_and_collect(demo, "fn2", ["w"]) == [
                "event: generating",
                'data: ["Hello, w!"]',
                "event: error",
                "data: null",
            ]


def test_compare_passwords_securely():
    """Equal passwords match (ASCII and non-ASCII); differing ones do not."""
    ascii_password = "password"
    non_ascii_password = "pässword"
    assert compare_passwords_securely(ascii_password, ascii_password)
    assert not compare_passwords_securely(ascii_password, non_ascii_password)
    assert compare_passwords_securely(non_ascii_password, non_ascii_password)


@pytest.mark.parametrize(
    "string, expected",
    [
        # Strings beginning with "<scheme>://".
        ("http://localhost:7860/", True),
        ("https://localhost:7860/", True),
        ("ftp://localhost:7860/", True),
        ("smb://example.com", True),
        ("ipfs://QmTzQ1Nj5R9BzF1djVQv8gvzZxVkJb1vhrLcXL1QyJzZE", True),
        # Relative paths, bare hosts and drive-letter paths.
        ("usr/local/bin", False),
        ("localhost:7860", False),
        ("localhost", False),
        ("C:/Users/username", False),
        # Two leading slashes/backslashes (in any mix) are expected to count;
        # doubled slashes later in the string are not.
        ("//path", True),
        ("\\\\path", True),
        ("/usr/bin//test", False),
        ("/\\10.0.225.200/share", True),
        ("\\/10.0.225.200/share", True),
        ("/home//user", False),
        ("C:\\folder\\file", False),
    ],
)
def test_starts_with_protocol(string, expected):
    """Check ``starts_with_protocol`` against each expected result in the table."""
    result = starts_with_protocol(string)
    assert result == expected
